from src.tasks.base_task.dataloader import BaseDataLoader
import numpy as np


        


class AnBn:
    """Stream of one-hot vectors for the a^n b^n counting language.

    Each episode draws ``n`` uniformly from the inclusive range ``[k, l]`` and
    yields, one vector per ``step()`` call: ``n`` copies of symbol ``a``
    (row 0 of ``np.eye(3)``), a separator (row 2), ``n`` copies of symbol
    ``b`` (row 1), and a second separator. A new ``n`` is then drawn.

    Args:
        k: smallest allowed ``n``.
        l: largest allowed ``n``; must satisfy ``k <= l``.
    """

    def __init__(self, k, l):
        assert k <= l
        self.k, self.l = k, l
        self.obs = np.eye(3)
        self.reset()

    def reset(self):
        """Begin a new episode: sample ``n`` and return to emitting ``a``."""
        self.n = np.random.randint(self.k, self.l + 1)
        self.counter = 0
        self.mode = 0  # 0 -> emitting a's, anything else -> emitting b's

    @property
    def input_size(self):
        """Length of each emitted one-hot vector."""
        return self.obs.shape[0]

    def step(self):
        """Emit the next token of the stream as a fresh one-hot vector."""
        if self.counter >= self.n:
            # Current run of n symbols is done: emit the separator and either
            # switch from a's to b's or start a brand-new episode.
            if self.mode != 0:
                self.reset()
            else:
                self.mode = 1
                self.counter = 0
            return self.obs[2].copy()
        self.counter += 1
        row = 0 if self.mode == 0 else 1
        return self.obs[row].copy()


class DistantBrackets:
    """Stream of one-hot vectors for a "distant brackets" recall task.

    Each episode is produced one vector per ``step()`` call, in the order::

        [  x_1 ... x_s  ]  d_1 ... d_k  [  x_1 ... x_s  ]  NL

    ``x_1 ... x_s`` are symbols drawn uniformly at random. They must be
    reproduced, in the same order, inside the second bracket pair.
    ``d_1 ... d_k`` are ``k`` independently drawn distractor symbols.

    Vectors are rows of ``np.eye(3 + a)``:

    - row 0 is ``[``;
    - row 1 is ``]``;
    - row 2 is the newline / episode separator ``NL``;
    - rows ``3 .. 2 + a`` are the ``a`` content symbols.

    Internal states (``last_state``):
        7 newline / episode start, 0 first ``[``, 1 memorize symbols,
        2 first ``]``, 3 distractors, 4 second ``[``, 5 recall symbols,
        6 second ``]``.

    Args:
        s: number of symbols to memorize and recall; must be >= 1.
        k: number of distractor symbols between the bracket pairs; must be >= 1.
        a: size of the content alphabet; must be >= 1.

    Raises:
        ValueError: if ``s``, ``k`` or ``a`` is smaller than 1. An ``s`` or
            ``k`` of 0 would otherwise leave the state machine stuck forever,
            and an ``a`` of 0 leaves no symbol to sample.
    """

    def __init__(self, s, k, a):
        if s < 1 or k < 1 or a < 1:
            raise ValueError(
                f"s, k and a must all be >= 1 (got s={s}, k={k}, a={a})"
            )
        self.s = s
        self.k = k
        self.a = a
        self.obs = np.eye(3 + a)
        self.reset()

    @property
    def input_size(self):
        """Length of each emitted one-hot vector (``3 + a``)."""
        return self.obs.shape[0]

    def reset(self):
        """Start a fresh episode: clear memorized symbols and counters."""
        self.last_state = 7
        self.history = []
        self.s_counter = 0
        self.k_counter = 0

    def step(self):
        """Advance the state machine by one token.

        Returns:
            A fresh copy of the one-hot vector for the newly entered state.
        """
        state = self.last_state
        if state == 7:
            self.last_state = 0
        elif state == 0:
            self.last_state = 1
        elif state == 1:
            # Stay here until s symbols have been memorized.
            if self.s_counter == self.s - 1:
                self.s_counter = 0
                self.last_state = 2
            else:
                self.s_counter += 1
        elif state == 2:
            self.last_state = 3
        elif state == 3:
            # Stay here until k distractors have been emitted.
            if self.k_counter == self.k - 1:
                self.k_counter = 0
                self.last_state = 4
            else:
                self.k_counter += 1
        elif state == 4:
            self.last_state = 5
        elif state == 5:
            # Replay history[0 .. s-1]; s_counter indexes the next symbol.
            if self.s_counter == self.s - 1:
                self.s_counter = 0
                self.last_state = 6
            else:
                self.s_counter += 1
        elif state == 6:
            # Episode complete; reset() also puts the machine back in state 7.
            self.reset()
        return self.return_obs(self.last_state)

    def _random_symbol(self):
        """Return a copy of a uniformly sampled content-symbol vector."""
        return self.obs[np.random.randint(3, 3 + self.a)].copy()

    def return_obs(self, state):
        """Return the vector emitted on entering ``state``.

        Every returned array is an independent copy. Mutating it never
        affects ``self.obs`` or the memorized ``history``.
        """
        if state == 0 or state == 4:
            return self.obs[0].copy()  # "["
        elif state == 2 or state == 6:
            return self.obs[1].copy()  # "]"
        elif state == 7:
            return self.obs[2].copy()  # newline
        elif state == 1:
            symbol = self._random_symbol()
            # Keep a private copy so callers cannot alter what gets replayed.
            self.history.append(symbol)
            return symbol.copy()
        elif state == 5:
            return self.history[self.s_counter].copy()
        elif state == 3:
            return self._random_symbol()
        
        
        
        